import onnx

'''
为简化分析逻辑，此处只判断了模型中是否含有Conv+BN的组合，经过优化后的模型应将BN融合进Conv中，
用于测试对的原始模型需含有Conv+BN组合(这里用了Densenet121作为测试模型)。
'''

def analyze(onnxfile, **kwargs):
    model = onnx.load(onnxfile)

    dict_conv = {}
    dict_bn = {}
    dict_mul = {}

    merge_mish = False

    for node_id, node in enumerate(model.graph.node):
        if node.op_type == 'Conv':
            dict_conv['input'] = node.input
            dict_conv['output'] = node.output

        if node.op_type == 'BatchNormalization':
            if len(dict_conv) > 0 and node.input[0] == dict_conv['output'][0]:
                print('got conv+bn pair')
                return False
            else:
                print('clear dict_conv')
                dict_conv = {}    

    return True